load("@com_google_protobuf//bazel:upb_proto_library.bzl", "upb_c_proto_library")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_python//python/uv:lock.bzl", uv_lock = "lock")
load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library", "zig_shared_library")
load("@zml//bazel:runfiles.bzl", "runfiles_to_default")

# A proxy PJRT plugin that loads the Neuron PJRT plugin
# and returns the API instance obtained from its nested GetPjrtApi call.
#
# It also provides a way to load the implicit transitive dependencies
# of neuronx-cc (see add_needed of the patchelf target).
zig_shared_library(
    name = "libpjrt_neuron",
    main = "libpjrt_neuron.zig",
    # Extra flag passed to the Zig compiler: disable stack checking.
    zigopts = ["-fno-stack-check"],
    deps = [
        ":libpython",
        "//stdx",
        "@rules_zig//zig/runfiles",
        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
    ],
    # Visible only to packages of the @libpjrt_neuron repository.
    visibility = ["@libpjrt_neuron//:__subpackages__"],
)

zig_binary(
    name = "neuronx-cc",
    main = "neuronx-cc.zig",
    deps = [":libpython"],
    # The neuronx_cc Python package is made available as runtime data.
    data = ["@neuron_py_deps//neuronx_cc"],
    # `$$` escapes Bazel "Make" variable substitution so the linker receives a
    # literal `$ORIGIN`: libraries are searched in ../lib relative to the
    # directory holding the binary.
    linkopts = ["-Wl,-rpath,$$ORIGIN/../lib"],
    # Excluded from wildcard patterns such as //...; consumed via
    # :neuronx-cc_files instead.
    tags = ["manual"],
)

runfiles_to_default(
    name = "neuronx-cc_files",
    # Wraps :neuronx-cc; presumably re-exports its runfiles as default
    # outputs (per the macro name) — confirm in @zml//bazel:runfiles.bzl.
    deps = [":neuronx-cc"],
    visibility = ["@libpjrt_neuron//:__subpackages__"],
)

cc_library(
    name = "libpython",
    hdrs = ["libpython.h"],
    # C headers and libraries of the Python interpreter selected by the
    # current rules_python toolchain.
    deps = [
        "@rules_python//python/cc:current_py_cc_headers",
        "@rules_python//python/cc:current_py_cc_libs",
    ],
)

cc_library(
    name = "empty",
    # Used by the "//conditions:default" branch of :neuron. cc_library
    # `defines` propagate to every target that depends on this one.
    defines = [
        "ZML_RUNTIME_NEURON_DISABLED",
    ],
)

cc_library(
    name = "zmlxneuron",
    # Used by the "//runtimes:neuron.enabled" branch of :neuron. cc_library
    # `defines` propagate to every target that depends on this one.
    defines = [
        "ZML_RUNTIME_NEURON",
    ],
)

cc_library(
    name = "libnrt_headers",
    # Local nrt.h, built on top of the headers exported by @libpjrt_neuron.
    hdrs = ["nrt.h"],
    deps = [
        "@libpjrt_neuron//:libnrt_headers",
    ],
)

filegroup(
    name = "layers",
    # Declared with no files; presumably a placeholder label expected by
    # shared packaging rules — confirm with the consumers of :layers.
    srcs = [],
    visibility = ["//visibility:public"],
)

upb_c_proto_library(
    name = "xla_data_upb",
    # upb C bindings generated from @xla//xla:xla_data_proto.
    deps = [
        "@xla//xla:xla_data_proto",
    ],
)

upb_c_proto_library(
    name = "hlo_proto_upb",
    # upb C bindings generated from @xla//xla/service:hlo_proto.
    deps = [
        "@xla//xla/service:hlo_proto",
    ],
)

zig_shared_library(
    name = "libneuronxla",
    main = "libneuronxla.zig",
    # Fixed output file name for the produced shared object.
    shared_lib_name = "libneuronxla.so",
    # Extra Zig compiler flags: no stack checking, position-independent code.
    zigopts = [
        "-fno-stack-check",
        "-fPIC",
    ],
    # Needs the upb bindings for the XLA HLO and xla_data protos.
    deps = [
        ":hlo_proto_upb",
        ":libpython",
        ":xla_data_upb",
        "//stdx",
        "//upb",
    ],
    # Visible only to packages of the @libpjrt_neuron repository.
    visibility = ["@libpjrt_neuron//:__subpackages__"],
)

uv_lock(
    name = "requirements",
    srcs = ["requirements.in"],
    out = "requirements.lock.txt",
    python_version = "3.11",
    # Extra flags forwarded to uv when resolving requirements.in.
    args = [
        # Record the index URL and find-links locations in the lock file.
        "--emit-index-url",
        "--emit-find-links",
        # Consider versions from every configured index, not just the first
        # index that provides a package.
        "--index-strategy=unsafe-best-match",
        # Re-resolve to the newest allowed versions instead of keeping
        # previously locked ones.
        "--upgrade",
    ],
    # Excluded from wildcard target patterns such as //...
    tags = ["manual"],
)

zig_library(
    name = "neuron",
    import_name = "runtimes/neuron",
    main = "neuron.zig",
    visibility = ["//visibility:public"],
    # //pjrt is always required; everything else depends on whether the
    # Neuron runtime is enabled.
    deps = ["//pjrt"] + select({
        # Enabled: pull in the plugin, headers and helpers, and define
        # ZML_RUNTIME_NEURON through :zmlxneuron.
        "//runtimes:neuron.enabled": [
            ":libnrt_headers",
            ":libpython",
            ":zmlxneuron",
            "//async",
            "//stdx",
            "@libpjrt_neuron",
            "@rules_zig//zig/runfiles",
        ],
        # Disabled: only define ZML_RUNTIME_NEURON_DISABLED through :empty.
        "//conditions:default": [":empty"],
    }),
)
